"""
工具注册系统
提供统一的工具注册和管理接口
"""

from typing import List, Dict, Any, Optional, Callable, Awaitable
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType

logger = get_logger("memory_retrieval_tools")


class MemoryRetrievalTool:
    """Base class for memory retrieval tools.

    Bundles a tool name, a description, parameter definitions and an async
    execution function. The same parameter definitions are rendered either as
    prompt text (``get_tool_description``) or as a function-calling definition
    (``get_tool_definition``). Both views treat a parameter without a
    "required" key as required, so they never disagree.
    """

    def __init__(
        self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
    ):
        """
        Initialize the tool.

        Args:
            name: Tool name.
            description: Tool description.
            parameters: Parameter definitions, e.g.
                [{"name": "param_name", "type": "string", "description": "what it is", "required": True}].
                An optional "enum" key lists the allowed values. A missing "required" key means required.
            execute_func: Execution function; must be an async function (``execute`` awaits it).
        """
        self.name = name
        self.description = description
        self.parameters = parameters
        self.execute_func = execute_func

    @staticmethod
    def _is_required(param: Dict[str, Any]) -> Any:
        """Return the parameter's "required" flag, defaulting to True when the key is absent."""
        # Single source of truth for the default, shared by the prompt text and the
        # function-calling definition. "Required" is the safer default: telling the LLM a
        # parameter is optional could make execute(**kwargs) fail with a TypeError when
        # execute_func has no default for that parameter.
        return param.get("required", True)

    def get_tool_description(self) -> str:
        """Build the plain-text description of this tool for use in a prompt.

        Returns:
            ``"name(p1, p2): description"``, then one indented line per parameter
            (name, type, required/optional, description). If the tool has no
            parameters, a single indented "无参数" line is used instead.
        """
        param_names: List[str] = []
        param_descriptions: List[str] = []
        for param in self.parameters:
            # Use .get() for every field so a definition missing "name" cannot raise KeyError.
            param_name = param.get("name", "")
            param_type = param.get("type", "string")
            param_desc = param.get("description", "")
            required_str = "必填" if self._is_required(param) else "可选"
            param_names.append(param_name)
            param_descriptions.append(f"   - {param_name} ({param_type}, {required_str}): {param_desc}")

        params_str = "\n".join(param_descriptions) if param_descriptions else "   无参数"
        return f"{self.name}({', '.join(param_names)}): {self.description}\n{params_str}"

    async def execute(self, **kwargs) -> str:
        """Run the tool by awaiting ``execute_func`` with the given keyword arguments."""
        return await self.execute_func(**kwargs)

    def get_tool_definition(self) -> Dict[str, Any]:
        """Build the tool definition for LLM function calling.

        The original comment says this format matches ``BaseTool.get_tool_definition()``.

        Returns:
            Dict[str, Any]: ``{"name": str, "description": str, "parameters": List[Tuple]}``.
            Each tuple is ``(param_name, ToolParamType, description, required, enum_values)``.
        """
        # Built once per call rather than once per parameter.
        type_mapping = {
            "string": ToolParamType.STRING,
            "str": ToolParamType.STRING,
            "integer": ToolParamType.INTEGER,
            "int": ToolParamType.INTEGER,
            "float": ToolParamType.FLOAT,
            # "number" is the JSON-Schema name for non-integer numerics.
            "number": ToolParamType.FLOAT,
            "boolean": ToolParamType.BOOLEAN,
            "bool": ToolParamType.BOOLEAN,
        }

        param_tuples = []
        for param in self.parameters:
            param_name = param.get("name", "")
            # `or "string"` also covers an explicit None type, which would otherwise crash on .lower().
            param_type_str = (param.get("type") or "string").lower()
            param_type = type_mapping.get(param_type_str)
            if param_type is None:
                # Unknown types still fall back to STRING, but no longer silently.
                logger.warning(f"记忆检索工具 {self.name} 的参数 {param_name} 类型 {param_type_str!r} 未知，按 string 处理")
                param_type = ToolParamType.STRING

            param_tuples.append(
                (
                    param_name,
                    param_type,
                    param.get("description", ""),
                    self._is_required(param),
                    param.get("enum", None),
                )
            )

        return {
            "name": self.name,
            "description": self.description,
            "parameters": param_tuples,
        }


class MemoryRetrievalToolRegistry:
    """工具注册器"""

    def __init__(self):
        self.tools: Dict[str, MemoryRetrievalTool] = {}

    def register_tool(self, tool: MemoryRetrievalTool) -> None:
        """注册工具"""
        if tool.name in self.tools:
            logger.debug(f"记忆检索工具 {tool.name} 已存在，跳过重复注册")
            return
        self.tools[tool.name] = tool
        logger.info(f"注册记忆检索工具: {tool.name}")

    def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
        """获取工具"""
        return self.tools.get(name)

    def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
        """获取所有工具"""
        return self.tools.copy()

    def get_tools_description(self) -> str:
        """获取所有工具的描述，用于prompt"""
        descriptions = []
        for i, tool in enumerate(self.tools.values(), 1):
            descriptions.append(f"{i}. {tool.get_tool_description()}")
        return "\n".join(descriptions)

    def get_action_types_list(self) -> str:
        """获取所有动作类型的列表，用于prompt（已废弃，保留用于兼容）"""
        action_types = [tool.name for tool in self.tools.values()]
        action_types.append("final_answer")
        action_types.append("no_answer")
        return " 或 ".join([f'"{at}"' for at in action_types])
    
    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """获取所有工具的定义列表，用于LLM function calling
        
        Returns:
            List[Dict[str, Any]]: 工具定义列表，每个元素是一个工具定义字典
        """
        return [tool.get_tool_definition() for tool in self.tools.values()]


# Module-wide registry instance, shared by register_memory_retrieval_tool() and get_tool_registry().
_tool_registry: MemoryRetrievalToolRegistry = MemoryRetrievalToolRegistry()


def register_memory_retrieval_tool(
    name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
    """Wrap the arguments in a MemoryRetrievalTool and add it to the module-wide registry.

    Duplicate-name handling is delegated to ``MemoryRetrievalToolRegistry.register_tool``.

    Args:
        name: Tool name.
        description: Tool description.
        parameters: Parameter definition list (see ``MemoryRetrievalTool.__init__``).
        execute_func: Async execution function.
    """
    new_tool = MemoryRetrievalTool(
        name=name,
        description=description,
        parameters=parameters,
        execute_func=execute_func,
    )
    _tool_registry.register_tool(new_tool)


def get_tool_registry() -> MemoryRetrievalToolRegistry:
    """Return the module-level MemoryRetrievalToolRegistry instance.

    Every call returns the same object that register_memory_retrieval_tool() registers into.
    """
    registry = _tool_registry
    return registry
